{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ProtoNN in Tensorflow\n",
    "\n",
    "This is a simple notebook that illustrates the usage of the TensorFlow implementation of ProtoNN. We are using the USPS dataset. Please refer to `fetch_usps.py` and `process_usps.py` for more details on downloading the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-15T13:06:10.223951Z",
     "start_time": "2018-08-15T13:06:09.303454Z"
    }
   },
   "outputs": [],
   "source": [
    "# Copyright (c) Microsoft Corporation. All rights reserved.\n",
    "# Licensed under the MIT license.\n",
    "\n",
    "from __future__ import print_function\n",
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from edgeml_tf.trainer.protoNNTrainer import ProtoNNTrainer\n",
    "from edgeml_tf.graph.protoNN import ProtoNN\n",
    "import edgeml_tf.utils as utils\n",
    "import helpermethods as helper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# USPS Data\n",
    "\n",
    "It is assumed that the USPS data has already been downloaded and set up with the help of [fetch_usps.py](fetch_usps.py) and is placed in the `./usps10` subdirectory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-15T13:06:10.271026Z",
     "start_time": "2018-08-15T13:06:10.225900Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load the preprocessed USPS splits. Column 0 of each array is used as the\n",
    "# label and the remaining columns as the feature vector.\n",
    "DATA_DIR = './usps10'\n",
    "train = np.load(os.path.join(DATA_DIR, 'train.npy'))\n",
    "test = np.load(os.path.join(DATA_DIR, 'test.npy'))\n",
    "x_train, y_train = train[:, 1:], train[:, 0]\n",
    "x_test, y_test = test[:, 1:], test[:, 0]\n",
    "\n",
    "# Size the one-hot encoding by the widest label range seen in either split.\n",
    "# ndarray reductions are used instead of the builtin max()/min(), which walk\n",
    "# the array element by element in Python.\n",
    "trainLabelSpan = y_train.max() - y_train.min() + 1\n",
    "testLabelSpan = y_test.max() - y_test.min() + 1\n",
    "numClasses = int(max(trainLabelSpan, testLabelSpan))\n",
    "\n",
    "y_train = helper.to_onehot(y_train, numClasses)\n",
    "y_test = helper.to_onehot(y_test, numClasses)\n",
    "dataDimension = x_train.shape[1]\n",
    "# Re-read the class count from the encoded labels so that the placeholders\n",
    "# built later match the actual one-hot width.\n",
    "numClasses = y_train.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Parameters\n",
    "\n",
    "Note that ProtoNN is very sensitive to the value of the hyperparameter $\\gamma$, here stored in the variable `GAMMA`. If `GAMMA` is set to `None`, the median heuristic will be used to estimate a good value of $\\gamma$ through the `helper.getGamma()` method. This method also returns the corresponding `W` and `B` matrices, which should be used to initialize ProtoNN (as is done here)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-15T13:06:10.279204Z",
     "start_time": "2018-08-15T13:06:10.272880Z"
    }
   },
   "outputs": [],
   "source": [
    "PROJECTION_DIM = 60\n",
    "NUM_PROTOTYPES = 60\n",
    "REG_W = 0.000005\n",
    "REG_B = 0.0\n",
    "REG_Z = 0.00005\n",
    "SPAR_W = 0.8\n",
    "SPAR_B = 1.0\n",
    "SPAR_Z = 1.0\n",
    "LEARNING_RATE = 0.05\n",
    "NUM_EPOCHS = 200\n",
    "BATCH_SIZE = 32\n",
    "GAMMA = 0.0015"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-15T13:06:10.307632Z",
     "start_time": "2018-08-15T13:06:10.280955Z"
    }
   },
   "outputs": [],
   "source": [
    "# Obtain gamma together with the W and B matrices used to initialize ProtoNN\n",
    "W, B, gamma = helper.getGamma(\n",
    "    GAMMA, PROJECTION_DIM, dataDimension, NUM_PROTOTYPES, x_train\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-15T13:07:22.641991Z",
     "start_time": "2018-08-15T13:06:10.309353Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:   0 Batch:   0 Loss: 5.85158 Accuracy: 0.03125\n",
      "Epoch:   1 Batch:   0 Loss: 1.53823 Accuracy: 0.65625\n",
      "Epoch:   2 Batch:   0 Loss: 0.81371 Accuracy: 0.87500\n",
      "Epoch:   3 Batch:   0 Loss: 0.51246 Accuracy: 0.87500\n",
      "Epoch:   4 Batch:   0 Loss: 0.41875 Accuracy: 0.93750\n",
      "Epoch:   5 Batch:   0 Loss: 0.36797 Accuracy: 0.96875\n",
      "Epoch:   6 Batch:   0 Loss: 0.32868 Accuracy: 0.96875\n",
      "Epoch:   7 Batch:   0 Loss: 0.30316 Accuracy: 0.96875\n",
      "Epoch:   8 Batch:   0 Loss: 0.29075 Accuracy: 0.96875\n",
      "Epoch:   9 Batch:   0 Loss: 0.28370 Accuracy: 0.96875\n",
      "Test Loss: 0.50615 Accuracy: 0.89497\n",
      "Epoch:  10 Batch:   0 Loss: 0.28014 Accuracy: 0.96875\n",
      "Epoch:  11 Batch:   0 Loss: 0.27734 Accuracy: 0.96875\n",
      "Epoch:  12 Batch:   0 Loss: 0.27511 Accuracy: 0.96875\n",
      "Epoch:  13 Batch:   0 Loss: 0.27126 Accuracy: 0.96875\n",
      "Epoch:  14 Batch:   0 Loss: 0.26776 Accuracy: 0.96875\n",
      "Epoch:  15 Batch:   0 Loss: 0.26506 Accuracy: 0.96875\n",
      "Epoch:  16 Batch:   0 Loss: 0.26371 Accuracy: 0.96875\n",
      "Epoch:  17 Batch:   0 Loss: 0.26249 Accuracy: 0.96875\n",
      "Epoch:  18 Batch:   0 Loss: 0.26094 Accuracy: 0.96875\n",
      "Epoch:  19 Batch:   0 Loss: 0.25879 Accuracy: 0.96875\n",
      "Test Loss: 0.54362 Accuracy: 0.89494\n",
      "Epoch:  20 Batch:   0 Loss: 0.25642 Accuracy: 0.96875\n",
      "Epoch:  21 Batch:   0 Loss: 0.25328 Accuracy: 0.96875\n",
      "Epoch:  22 Batch:   0 Loss: 0.25015 Accuracy: 0.96875\n",
      "Epoch:  23 Batch:   0 Loss: 0.24684 Accuracy: 0.96875\n",
      "Epoch:  24 Batch:   0 Loss: 0.24365 Accuracy: 0.96875\n",
      "Epoch:  25 Batch:   0 Loss: 0.24023 Accuracy: 0.96875\n",
      "Epoch:  26 Batch:   0 Loss: 0.23747 Accuracy: 0.96875\n",
      "Epoch:  27 Batch:   0 Loss: 0.23460 Accuracy: 0.96875\n",
      "Epoch:  28 Batch:   0 Loss: 0.23170 Accuracy: 0.96875\n",
      "Epoch:  29 Batch:   0 Loss: 0.22903 Accuracy: 0.96875\n",
      "Test Loss: 0.54884 Accuracy: 0.89391\n",
      "Epoch:  30 Batch:   0 Loss: 0.22662 Accuracy: 0.96875\n",
      "Epoch:  31 Batch:   0 Loss: 0.22448 Accuracy: 0.96875\n",
      "Epoch:  32 Batch:   0 Loss: 0.22245 Accuracy: 0.96875\n",
      "Epoch:  33 Batch:   0 Loss: 0.22068 Accuracy: 0.96875\n",
      "Epoch:  34 Batch:   0 Loss: 0.21904 Accuracy: 0.96875\n",
      "Epoch:  35 Batch:   0 Loss: 0.21723 Accuracy: 0.96875\n",
      "Epoch:  36 Batch:   0 Loss: 0.21582 Accuracy: 0.96875\n",
      "Epoch:  37 Batch:   0 Loss: 0.21409 Accuracy: 0.96875\n",
      "Epoch:  38 Batch:   0 Loss: 0.21246 Accuracy: 0.96875\n",
      "Epoch:  39 Batch:   0 Loss: 0.21095 Accuracy: 0.96875\n",
      "Test Loss: 0.52917 Accuracy: 0.90091\n",
      "Epoch:  40 Batch:   0 Loss: 0.20928 Accuracy: 0.96875\n",
      "Epoch:  41 Batch:   0 Loss: 0.20770 Accuracy: 0.96875\n",
      "Epoch:  42 Batch:   0 Loss: 0.20633 Accuracy: 0.96875\n",
      "Epoch:  43 Batch:   0 Loss: 0.20512 Accuracy: 0.96875\n",
      "Epoch:  44 Batch:   0 Loss: 0.20377 Accuracy: 0.96875\n",
      "Epoch:  45 Batch:   0 Loss: 0.20240 Accuracy: 0.96875\n",
      "Epoch:  46 Batch:   0 Loss: 0.20124 Accuracy: 0.96875\n",
      "Epoch:  47 Batch:   0 Loss: 0.20002 Accuracy: 0.96875\n",
      "Epoch:  48 Batch:   0 Loss: 0.19910 Accuracy: 0.96875\n",
      "Epoch:  49 Batch:   0 Loss: 0.19808 Accuracy: 0.96875\n",
      "Test Loss: 0.50988 Accuracy: 0.90292\n",
      "Epoch:  50 Batch:   0 Loss: 0.19705 Accuracy: 0.96875\n",
      "Epoch:  51 Batch:   0 Loss: 0.19629 Accuracy: 1.00000\n",
      "Epoch:  52 Batch:   0 Loss: 0.19560 Accuracy: 1.00000\n",
      "Epoch:  53 Batch:   0 Loss: 0.19483 Accuracy: 1.00000\n",
      "Epoch:  54 Batch:   0 Loss: 0.19404 Accuracy: 1.00000\n",
      "Epoch:  55 Batch:   0 Loss: 0.19351 Accuracy: 1.00000\n",
      "Epoch:  56 Batch:   0 Loss: 0.19279 Accuracy: 1.00000\n",
      "Epoch:  57 Batch:   0 Loss: 0.19250 Accuracy: 1.00000\n",
      "Epoch:  58 Batch:   0 Loss: 0.19207 Accuracy: 1.00000\n",
      "Epoch:  59 Batch:   0 Loss: 0.19169 Accuracy: 1.00000\n",
      "Test Loss: 0.48988 Accuracy: 0.90443\n",
      "Epoch:  60 Batch:   0 Loss: 0.19146 Accuracy: 1.00000\n",
      "Epoch:  61 Batch:   0 Loss: 0.19119 Accuracy: 1.00000\n",
      "Epoch:  62 Batch:   0 Loss: 0.19095 Accuracy: 1.00000\n",
      "Epoch:  63 Batch:   0 Loss: 0.19077 Accuracy: 1.00000\n",
      "Epoch:  64 Batch:   0 Loss: 0.19066 Accuracy: 1.00000\n",
      "Epoch:  65 Batch:   0 Loss: 0.19071 Accuracy: 1.00000\n",
      "Epoch:  66 Batch:   0 Loss: 0.19066 Accuracy: 1.00000\n",
      "Epoch:  67 Batch:   0 Loss: 0.19071 Accuracy: 1.00000\n",
      "Epoch:  68 Batch:   0 Loss: 0.19083 Accuracy: 1.00000\n",
      "Epoch:  69 Batch:   0 Loss: 0.19090 Accuracy: 1.00000\n",
      "Test Loss: 0.47286 Accuracy: 0.90841\n",
      "Epoch:  70 Batch:   0 Loss: 0.19108 Accuracy: 1.00000\n",
      "Epoch:  71 Batch:   0 Loss: 0.19110 Accuracy: 1.00000\n",
      "Epoch:  72 Batch:   0 Loss: 0.19109 Accuracy: 1.00000\n",
      "Epoch:  73 Batch:   0 Loss: 0.19118 Accuracy: 1.00000\n",
      "Epoch:  74 Batch:   0 Loss: 0.19115 Accuracy: 1.00000\n",
      "Epoch:  75 Batch:   0 Loss: 0.19122 Accuracy: 1.00000\n",
      "Epoch:  76 Batch:   0 Loss: 0.19099 Accuracy: 1.00000\n",
      "Epoch:  77 Batch:   0 Loss: 0.19087 Accuracy: 1.00000\n",
      "Epoch:  78 Batch:   0 Loss: 0.19070 Accuracy: 1.00000\n",
      "Epoch:  79 Batch:   0 Loss: 0.19069 Accuracy: 0.96875\n",
      "Test Loss: 0.46010 Accuracy: 0.91289\n",
      "Epoch:  80 Batch:   0 Loss: 0.19055 Accuracy: 0.96875\n",
      "Epoch:  81 Batch:   0 Loss: 0.19055 Accuracy: 0.96875\n",
      "Epoch:  82 Batch:   0 Loss: 0.19013 Accuracy: 0.96875\n",
      "Epoch:  83 Batch:   0 Loss: 0.19005 Accuracy: 0.96875\n",
      "Epoch:  84 Batch:   0 Loss: 0.18991 Accuracy: 0.96875\n",
      "Epoch:  85 Batch:   0 Loss: 0.18985 Accuracy: 0.96875\n",
      "Epoch:  86 Batch:   0 Loss: 0.18961 Accuracy: 0.96875\n",
      "Epoch:  87 Batch:   0 Loss: 0.18926 Accuracy: 0.96875\n",
      "Epoch:  88 Batch:   0 Loss: 0.18901 Accuracy: 0.96875\n",
      "Epoch:  89 Batch:   0 Loss: 0.18866 Accuracy: 0.96875\n",
      "Test Loss: 0.45069 Accuracy: 0.91588\n",
      "Epoch:  90 Batch:   0 Loss: 0.18821 Accuracy: 0.96875\n",
      "Epoch:  91 Batch:   0 Loss: 0.18801 Accuracy: 0.96875\n",
      "Epoch:  92 Batch:   0 Loss: 0.18799 Accuracy: 0.96875\n",
      "Epoch:  93 Batch:   0 Loss: 0.18779 Accuracy: 0.96875\n",
      "Epoch:  94 Batch:   0 Loss: 0.18743 Accuracy: 0.96875\n",
      "Epoch:  95 Batch:   0 Loss: 0.18732 Accuracy: 0.96875\n",
      "Epoch:  96 Batch:   0 Loss: 0.18720 Accuracy: 0.96875\n",
      "Epoch:  97 Batch:   0 Loss: 0.18696 Accuracy: 0.96875\n",
      "Epoch:  98 Batch:   0 Loss: 0.18674 Accuracy: 0.96875\n",
      "Epoch:  99 Batch:   0 Loss: 0.18637 Accuracy: 0.96875\n",
      "Test Loss: 0.44313 Accuracy: 0.91739\n",
      "Epoch: 100 Batch:   0 Loss: 0.18625 Accuracy: 0.96875\n",
      "Epoch: 101 Batch:   0 Loss: 0.18607 Accuracy: 0.96875\n",
      "Epoch: 102 Batch:   0 Loss: 0.18596 Accuracy: 0.96875\n",
      "Epoch: 103 Batch:   0 Loss: 0.18585 Accuracy: 0.96875\n",
      "Epoch: 104 Batch:   0 Loss: 0.18578 Accuracy: 0.96875\n",
      "Epoch: 105 Batch:   0 Loss: 0.18561 Accuracy: 0.96875\n",
      "Epoch: 106 Batch:   0 Loss: 0.18538 Accuracy: 0.96875\n",
      "Epoch: 107 Batch:   0 Loss: 0.18525 Accuracy: 0.96875\n",
      "Epoch: 108 Batch:   0 Loss: 0.18512 Accuracy: 0.96875\n",
      "Epoch: 109 Batch:   0 Loss: 0.18494 Accuracy: 0.96875\n",
      "Test Loss: 0.43641 Accuracy: 0.91939\n",
      "Epoch: 110 Batch:   0 Loss: 0.18505 Accuracy: 0.96875\n",
      "Epoch: 111 Batch:   0 Loss: 0.18486 Accuracy: 0.96875\n",
      "Epoch: 112 Batch:   0 Loss: 0.18481 Accuracy: 0.96875\n",
      "Epoch: 113 Batch:   0 Loss: 0.18461 Accuracy: 0.96875\n",
      "Epoch: 114 Batch:   0 Loss: 0.18439 Accuracy: 0.96875\n",
      "Epoch: 115 Batch:   0 Loss: 0.18411 Accuracy: 0.96875\n",
      "Epoch: 116 Batch:   0 Loss: 0.18385 Accuracy: 0.96875\n",
      "Epoch: 117 Batch:   0 Loss: 0.18367 Accuracy: 0.96875\n",
      "Epoch: 118 Batch:   0 Loss: 0.18353 Accuracy: 0.96875\n",
      "Epoch: 119 Batch:   0 Loss: 0.18333 Accuracy: 0.96875\n",
      "Test Loss: 0.43042 Accuracy: 0.92185\n",
      "Epoch: 120 Batch:   0 Loss: 0.18315 Accuracy: 0.96875\n",
      "Epoch: 121 Batch:   0 Loss: 0.18304 Accuracy: 0.96875\n",
      "Epoch: 122 Batch:   0 Loss: 0.18268 Accuracy: 0.96875\n",
      "Epoch: 123 Batch:   0 Loss: 0.18227 Accuracy: 0.96875\n",
      "Epoch: 124 Batch:   0 Loss: 0.18209 Accuracy: 0.96875\n",
      "Epoch: 125 Batch:   0 Loss: 0.18206 Accuracy: 0.96875\n",
      "Epoch: 126 Batch:   0 Loss: 0.18184 Accuracy: 0.96875\n",
      "Epoch: 127 Batch:   0 Loss: 0.18185 Accuracy: 0.96875\n",
      "Epoch: 128 Batch:   0 Loss: 0.18159 Accuracy: 0.96875\n",
      "Epoch: 129 Batch:   0 Loss: 0.18152 Accuracy: 0.96875\n",
      "Test Loss: 0.42608 Accuracy: 0.92284\n",
      "Epoch: 130 Batch:   0 Loss: 0.18125 Accuracy: 0.96875\n",
      "Epoch: 131 Batch:   0 Loss: 0.18102 Accuracy: 0.96875\n",
      "Epoch: 132 Batch:   0 Loss: 0.18075 Accuracy: 0.96875\n",
      "Epoch: 133 Batch:   0 Loss: 0.18055 Accuracy: 0.96875\n",
      "Epoch: 134 Batch:   0 Loss: 0.18025 Accuracy: 0.96875\n",
      "Epoch: 135 Batch:   0 Loss: 0.18015 Accuracy: 0.96875\n",
      "Epoch: 136 Batch:   0 Loss: 0.18000 Accuracy: 0.96875\n",
      "Epoch: 137 Batch:   0 Loss: 0.17977 Accuracy: 0.96875\n",
      "Epoch: 138 Batch:   0 Loss: 0.17963 Accuracy: 0.96875\n",
      "Epoch: 139 Batch:   0 Loss: 0.17950 Accuracy: 0.96875\n",
      "Test Loss: 0.42305 Accuracy: 0.92238\n",
      "Epoch: 140 Batch:   0 Loss: 0.17935 Accuracy: 0.96875\n",
      "Epoch: 141 Batch:   0 Loss: 0.17908 Accuracy: 0.96875\n",
      "Epoch: 142 Batch:   0 Loss: 0.17903 Accuracy: 0.96875\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 143 Batch:   0 Loss: 0.17900 Accuracy: 0.96875\n",
      "Epoch: 144 Batch:   0 Loss: 0.17877 Accuracy: 0.96875\n",
      "Epoch: 145 Batch:   0 Loss: 0.17863 Accuracy: 0.96875\n",
      "Epoch: 146 Batch:   0 Loss: 0.17844 Accuracy: 0.96875\n",
      "Epoch: 147 Batch:   0 Loss: 0.17825 Accuracy: 1.00000\n",
      "Epoch: 148 Batch:   0 Loss: 0.17808 Accuracy: 1.00000\n",
      "Epoch: 149 Batch:   0 Loss: 0.17798 Accuracy: 1.00000\n",
      "Test Loss: 0.42054 Accuracy: 0.92288\n",
      "Epoch: 150 Batch:   0 Loss: 0.17788 Accuracy: 1.00000\n",
      "Epoch: 151 Batch:   0 Loss: 0.17773 Accuracy: 1.00000\n",
      "Epoch: 152 Batch:   0 Loss: 0.17753 Accuracy: 1.00000\n",
      "Epoch: 153 Batch:   0 Loss: 0.17743 Accuracy: 1.00000\n",
      "Epoch: 154 Batch:   0 Loss: 0.17736 Accuracy: 1.00000\n",
      "Epoch: 155 Batch:   0 Loss: 0.17724 Accuracy: 1.00000\n",
      "Epoch: 156 Batch:   0 Loss: 0.17721 Accuracy: 1.00000\n",
      "Epoch: 157 Batch:   0 Loss: 0.17721 Accuracy: 1.00000\n",
      "Epoch: 158 Batch:   0 Loss: 0.17713 Accuracy: 1.00000\n",
      "Epoch: 159 Batch:   0 Loss: 0.17701 Accuracy: 1.00000\n",
      "Test Loss: 0.41836 Accuracy: 0.92337\n",
      "Epoch: 160 Batch:   0 Loss: 0.17695 Accuracy: 1.00000\n",
      "Epoch: 161 Batch:   0 Loss: 0.17691 Accuracy: 1.00000\n",
      "Epoch: 162 Batch:   0 Loss: 0.17687 Accuracy: 1.00000\n",
      "Epoch: 163 Batch:   0 Loss: 0.17687 Accuracy: 1.00000\n",
      "Epoch: 164 Batch:   0 Loss: 0.17682 Accuracy: 1.00000\n",
      "Epoch: 165 Batch:   0 Loss: 0.17686 Accuracy: 1.00000\n",
      "Epoch: 166 Batch:   0 Loss: 0.17686 Accuracy: 1.00000\n",
      "Epoch: 167 Batch:   0 Loss: 0.17687 Accuracy: 1.00000\n",
      "Epoch: 168 Batch:   0 Loss: 0.17688 Accuracy: 1.00000\n",
      "Epoch: 169 Batch:   0 Loss: 0.17688 Accuracy: 1.00000\n",
      "Test Loss: 0.41637 Accuracy: 0.92536\n",
      "Epoch: 170 Batch:   0 Loss: 0.17685 Accuracy: 1.00000\n",
      "Epoch: 171 Batch:   0 Loss: 0.17684 Accuracy: 1.00000\n",
      "Epoch: 172 Batch:   0 Loss: 0.17688 Accuracy: 1.00000\n",
      "Epoch: 173 Batch:   0 Loss: 0.17695 Accuracy: 1.00000\n",
      "Epoch: 174 Batch:   0 Loss: 0.17693 Accuracy: 1.00000\n",
      "Epoch: 175 Batch:   0 Loss: 0.17693 Accuracy: 1.00000\n",
      "Epoch: 176 Batch:   0 Loss: 0.17697 Accuracy: 1.00000\n",
      "Epoch: 177 Batch:   0 Loss: 0.17707 Accuracy: 1.00000\n",
      "Epoch: 178 Batch:   0 Loss: 0.17715 Accuracy: 1.00000\n",
      "Epoch: 179 Batch:   0 Loss: 0.17728 Accuracy: 1.00000\n",
      "Test Loss: 0.41443 Accuracy: 0.92585\n",
      "Epoch: 180 Batch:   0 Loss: 0.17732 Accuracy: 1.00000\n",
      "Epoch: 181 Batch:   0 Loss: 0.17738 Accuracy: 1.00000\n",
      "Epoch: 182 Batch:   0 Loss: 0.17747 Accuracy: 1.00000\n",
      "Epoch: 183 Batch:   0 Loss: 0.17750 Accuracy: 1.00000\n",
      "Epoch: 184 Batch:   0 Loss: 0.17762 Accuracy: 1.00000\n",
      "Epoch: 185 Batch:   0 Loss: 0.17775 Accuracy: 1.00000\n",
      "Epoch: 186 Batch:   0 Loss: 0.17791 Accuracy: 1.00000\n",
      "Epoch: 187 Batch:   0 Loss: 0.17795 Accuracy: 1.00000\n",
      "Epoch: 188 Batch:   0 Loss: 0.17811 Accuracy: 1.00000\n",
      "Epoch: 189 Batch:   0 Loss: 0.17818 Accuracy: 1.00000\n",
      "Test Loss: 0.41260 Accuracy: 0.92536\n",
      "Epoch: 190 Batch:   0 Loss: 0.17835 Accuracy: 1.00000\n",
      "Epoch: 191 Batch:   0 Loss: 0.17847 Accuracy: 1.00000\n",
      "Epoch: 192 Batch:   0 Loss: 0.17855 Accuracy: 1.00000\n",
      "Epoch: 193 Batch:   0 Loss: 0.17863 Accuracy: 1.00000\n",
      "Epoch: 194 Batch:   0 Loss: 0.17880 Accuracy: 1.00000\n",
      "Epoch: 195 Batch:   0 Loss: 0.17885 Accuracy: 1.00000\n",
      "Epoch: 196 Batch:   0 Loss: 0.17900 Accuracy: 1.00000\n",
      "Epoch: 197 Batch:   0 Loss: 0.17912 Accuracy: 1.00000\n",
      "Epoch: 198 Batch:   0 Loss: 0.17927 Accuracy: 1.00000\n",
      "Epoch: 199 Batch:   0 Loss: 0.17948 Accuracy: 1.00000\n",
      "Test Loss: 0.41087 Accuracy: 0.92536\n"
     ]
    }
   ],
   "source": [
    "# Placeholders for a batch of feature vectors and the matching one-hot labels\n",
    "X = tf.placeholder(tf.float32, shape=[None, dataDimension], name='X')\n",
    "Y = tf.placeholder(tf.float32, shape=[None, numClasses], name='Y')\n",
    "\n",
    "# Build the ProtoNN graph, starting from the W and B returned by getGamma()\n",
    "protoNN = ProtoNN(\n",
    "    dataDimension, PROJECTION_DIM, NUM_PROTOTYPES, numClasses, gamma, W=W, B=B\n",
    ")\n",
    "trainer = ProtoNNTrainer(\n",
    "    protoNN,\n",
    "    REG_W, REG_B, REG_Z,\n",
    "    SPAR_W, SPAR_B, SPAR_Z,\n",
    "    LEARNING_RATE, X, Y,\n",
    "    lossType='xentropy'\n",
    ")\n",
    "\n",
    "# The session is left open because the evaluation cell reuses it\n",
    "sess = tf.Session()\n",
    "trainer.train(\n",
    "    BATCH_SIZE, NUM_EPOCHS, sess,\n",
    "    x_train, x_test, y_train, y_test,\n",
    "    printStep=600, valStep=10\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-15T13:07:22.671507Z",
     "start_time": "2018-08-15T13:07:22.645050Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final test accuracy 0.92526156\n",
      "Model size constraint (Bytes):  78240\n",
      "Number of non-zeros:  19560\n",
      "Actual model size:  78240\n",
      "Actual non-zeros:  16488\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of the trained model on the test split\n",
    "acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})\n",
    "\n",
    "# getModelMatrices() returns tensorflow graph nodes. They get their own names\n",
    "# so that the W and B initial matrices from getGamma() are not overwritten;\n",
    "# otherwise re-running the training cell would pass graph nodes as W= / B=.\n",
    "W_node, B_node, Z_node, _ = protoNN.getModelMatrices()\n",
    "matrixList = sess.run([W_node, B_node, Z_node])\n",
    "sparsityList = [SPAR_W, SPAR_B, SPAR_Z]\n",
    "\n",
    "# Default mode of getModelSize(): reported as the size constraint\n",
    "nnz, size, sparse = helper.getModelSize(matrixList, sparsityList)\n",
    "print(\"Final test accuracy\", acc)\n",
    "print(\"Model size constraint (Bytes): \", size)\n",
    "print(\"Number of non-zeros: \", nnz)\n",
    "\n",
    "# expected=False: reported as the actual size and non-zero count\n",
    "nnz, size, sparse = helper.getModelSize(matrixList, sparsityList, expected=False)\n",
    "print(\"Actual model size: \", size)\n",
    "print(\"Actual non-zeros: \", nnz)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
